# =============================================================================
# Statistical Analysis Code for:
# "Chronic adaptive versus conventional DBS response patterns in Parkinson's
#  disease: A pilot randomized crossover trial"
#
# Leave-One-Out (LOO) Sensitivity Analysis for Treatment Effect Modifiers
# Author: Jun Tanimura, MD, MSc.
# Created: 2025-11-01
# Last Updated: 2025-11-25
#
# This script performs LOO sensitivity analysis for all outcome × baseline
# modifier combinations to assess the robustness of interaction estimates
# in a small sample (n=9 patients).
# Based on: 03_ANCOVA_EXP_v3.3_20250907.R
# =============================================================================

library(tidyverse)
library(lme4)
library(lmerTest)

# =============================================================================
# Load and prepare data (same as main analysis)
# =============================================================================

# Configuration (same as main analysis)
# Outcomes analysed, in reporting order. The *_time outcomes are read from
# "Mean_corr_" source columns when the long-format data is built.
eval_items_ordered <- c(
  "ADL",
  paste0("UPDRS_part_", 1:4),
  "UPDRS_total",
  "ON_time", "OFF_time",
  "Dys_sev_time", "Dys_weak_time",
  "MMSE", "PDQ_39"
)

# Names for which a LOWER value is better. The modifier analysis multiplies
# the interaction partial r by -1 for any outcome listed here; outcomes not
# listed keep their sign. (age, PD_duration, P_HY and LEDD are not outcomes
# and therefore never trigger the flip.)
lower_the_better_metrics <- c(
  paste0("UPDRS_part_", 1:4),
  "UPDRS_total", "PDQ_39",
  "Dys_sev_time", "Dys_weak_time", "OFF_time",
  "age", "PD_duration", "P_HY", "LEDD"
)

# Outcomes for which a HIGHER value is better. Not referenced elsewhere in
# this script: any outcome absent from lower_the_better_metrics is already
# treated as higher-is-better.
higher_the_better_metrics <- c("MMSE", "ON_time")

# Baseline characteristics screened as candidate treatment-effect modifiers.
baseline_predictors <- c(
  "age", "PD_duration", "sex", "LEDD",
  "P_ADL", "P_MMSE", "P_HY",
  paste0("P_UPDRS_part_", 1:4), "P_UPDRS_total",
  "P_Mean_corr_ON_time", "P_Mean_corr_OFF_time",
  "P_Mean_corr_Dys_sev_time", "P_Mean_corr_Dys_weak_time",
  "P_PDQ_39"
)

# Predictors negated during sign unification (a higher raw value is a better
# baseline, so the sign is flipped to match the "higher = worse" convention).
predictor_better_when_higher <- c("P_ADL", "P_MMSE", "P_Mean_corr_ON_time")

# Predictors already oriented "higher = worse"; sign unification leaves them
# unchanged.
predictor_worse_when_higher <- c(
  "age", "PD_duration", "LEDD",
  paste0("P_UPDRS_part_", 1:4), "P_UPDRS_total",
  "P_HY", "P_PDQ_39",
  "P_Mean_corr_OFF_time", "P_Mean_corr_Dys_sev_time",
  "P_Mean_corr_Dys_weak_time"
)

# Data Processing (same as main analysis)
# Use data.csv from the working directory when present; otherwise fall back
# to the dated export. Abort when neither file exists.
data_file <- Find(file.exists, c("data.csv", "01_data_20250824.csv"))
if (is.null(data_file)) {
  stop("Data file not found. Please check the path to data.csv")
}

df <- read_csv(data_file)
# Replace hyphens in column headers with underscores so they match the names
# used throughout this script.
names(df) <- gsub("-", "_", names(df), fixed = TRUE)

# Candidate predictors that are actually present in the file (original order).
existing_predictors <- intersect(baseline_predictors, names(df))
df_sign_unified <- df

# Apply unified sign convention
# Predictors in predictor_better_when_higher are negated so that every numeric
# predictor shares the "higher = worse" orientation; sex is categorical and
# is excluded from this step.
predictor_vars <- setdiff(existing_predictors, "sex")

for (var_name in predictor_vars) {
  raw_values <- df_sign_unified[[var_name]]
  df_sign_unified[[paste0(var_name, "_unified")]] <-
    if (var_name %in% predictor_better_when_higher) -raw_values else raw_values
}

# Z-score standardization of each sign-unified predictor. This is a separate
# pass so that all *_unified columns precede all *_standardized columns.
for (var_name in predictor_vars) {
  df_sign_unified[[paste0(var_name, "_standardized")]] <-
    as.numeric(scale(df_sign_unified[[paste0(var_name, "_unified")]]))
}

# Modifier columns carried into the long-format data: every standardized
# predictor, plus raw sex when it is available.
available_standardized_predictors <- c(
  paste0(predictor_vars, "_standardized"),
  if ("sex" %in% existing_predictors) "sex"
)

# Transform to long format (same as main analysis)
# Build one row per patient x treatment arm x outcome. For each outcome the
# P_* column supplies the baseline, and the C_* / A_* columns hold the
# post-treatment scores under cDBS / aDBS. Outcomes missing any of these
# columns are skipped.
ancova_data <- eval_items_ordered %>%
  map_dfr(function(item) {
    # Time outcomes live under a "Mean_corr_" stem; all other outcomes
    # (PDQ_39 included) use their own name as the stem.
    stem <- switch(item,
      ON_time       = "Mean_corr_ON_time",
      OFF_time      = "Mean_corr_OFF_time",
      Dys_sev_time  = "Mean_corr_Dys_sev_time",
      Dys_weak_time = "Mean_corr_Dys_weak_time",
      item
    )
    col_P <- paste0("P_", stem)
    col_C <- paste0("C_", stem)
    col_A <- paste0("A_", stem)

    if (!all(c(col_P, col_C, col_A) %in% names(df_sign_unified))) {
      return(NULL)
    }

    # Outcomes whose name contains "time" are modelled as log1p of the value
    # truncated at zero; everything else stays on its raw scale.
    is_time <- grepl("time", item, fixed = TRUE)
    to_model_scale <- function(x) {
      if (is_time) log1p(pmax(x, 0)) else x
    }

    long_rows <- df_sign_unified %>%
      select(PatientID, Group, all_of(c(col_P, col_C, col_A)),
             all_of(available_standardized_predictors)) %>%
      pivot_longer(
        cols = all_of(c(col_C, col_A)),
        names_to = "Treatment_raw",
        values_to = "Score_post"
      )

    long_rows %>%
      mutate(
        Baseline = .data[[col_P]],
        Treatment = if_else(Treatment_raw == col_C, "cDBS", "aDBS"),
        # Crossover order: Group 1 receives cDBS first, Group 2 aDBS first.
        Period = case_when(
          Group == 1 & Treatment == "cDBS" ~ "Phase1",
          Group == 1 & Treatment == "aDBS" ~ "Phase2",
          Group == 2 & Treatment == "aDBS" ~ "Phase1",
          Group == 2 & Treatment == "cDBS" ~ "Phase2"
        ),
        Evaluation = item,  # original item name (ON_time, OFF_time, ...)
        Score_transformed = to_model_scale(Score_post),
        Baseline_transformed = to_model_scale(Baseline)
      ) %>%
      select(PatientID, Group, Evaluation, Treatment, Period,
             Baseline, Score_post, Baseline_transformed, Score_transformed,
             all_of(available_standardized_predictors))
  }) %>%
  filter(!is.na(Score_post)) %>%
  mutate(
    Treatment = factor(Treatment, levels = c("cDBS", "aDBS")),
    Period = factor(Period, levels = c("Phase1", "Phase2")),
    Evaluation = factor(Evaluation, levels = eval_items_ordered),
    PatientID = factor(PatientID)
  ) %>%
  # Standardize the (transformed) baseline within each outcome.
  group_by(Evaluation) %>%
  mutate(Baseline_z = as.numeric(scale(Baseline_transformed))) %>%
  ungroup()

if ("sex" %in% existing_predictors) {
  ancova_data$sex <- factor(ancova_data$sex)
}

# =============================================================================
# Functions for modifier analysis with LOO
# =============================================================================

# Original modifier analysis function (for full dataset)
# Fit the treatment-by-modifier interaction model for one outcome on all
# available patients and summarise the interaction as a sign-unified partial r.
#
# Model (REML, bobyqa optimizer):
#   Score_transformed ~ Treatment * <modifier_var> + Period + Sequence +
#                       Baseline_z + (1 | PatientID)
# where Sequence is the randomisation Group coded as a factor.
#
# Args:
#   data:         long-format data with Evaluation, Score_transformed,
#                 Baseline_z, Treatment, Period, Group and PatientID columns.
#   metric:       outcome name selected from data$Evaluation.
#   modifier_var: column name of the baseline modifier (numeric or factor).
#
# Returns NULL when fewer than 10 complete rows remain, the model errors, or
# the interaction term cannot be identified uniquely. Otherwise returns a list:
#   r_full     - t / sqrt(t^2 + df) of the interaction, multiplied by -1 when
#                `metric` is in lower_the_better_metrics (NA if t or df is
#                missing or non-finite),
#   n_patients - distinct patients used in the fit,
#   n_obs      - rows used in the fit.
run_modifier_analysis_full <- function(data, metric, modifier_var) {
  metric_data <- data %>%
    filter(Evaluation == metric) %>%
    filter(!is.na(Score_transformed) & !is.na(Baseline_z) &
             !is.na(!!sym(modifier_var)))

  if (nrow(metric_data) < 10) return(NULL)

  metric_data <- metric_data %>% mutate(Sequence = factor(Group))

  modifier_formula <- paste("Score_transformed ~ Treatment * ", modifier_var,
                            "+ Period + Sequence + Baseline_z + (1|PatientID)")

  # A fitting error (e.g. a factor modifier left with a single level) means
  # "no estimate" for this combination rather than aborting the whole sweep.
  modifier_model <- tryCatch(
    lmer(as.formula(modifier_formula), data = metric_data, REML = TRUE,
         control = lmerControl(optimizer = "bobyqa")),
    error = function(e) NULL
  )

  if (is.null(modifier_model)) return(NULL)

  coef_table <- summary(modifier_model)$coefficients
  coef_names <- rownames(coef_table)

  # A numeric modifier yields the term "TreatmentaDBS:<var>". A factor
  # modifier such as sex yields "TreatmentaDBS:<var><level>" instead. Accept
  # the exact name, or else a single level-suffixed term. A multi-level
  # factor produces several terms and cannot be summarised by one partial r.
  interaction_prefix <- paste0("TreatmentaDBS:", modifier_var)
  interaction_coef_name <- if (interaction_prefix %in% coef_names) {
    interaction_prefix
  } else {
    coef_names[startsWith(coef_names, interaction_prefix)]
  }
  if (length(interaction_coef_name) != 1) return(NULL)

  interaction_effect <- coef_table[interaction_coef_name, ]

  # Look up columns by name. The "df" column exists only when lmerTest
  # augments the summary; positional indexing would silently read the wrong
  # column without it.
  t_val <- unname(interaction_effect["t value"])
  df_val <- unname(interaction_effect["df"])

  if (is.finite(t_val) && is.finite(df_val) && df_val > 0) {
    r_raw <- t_val / sqrt(t_val^2 + df_val)

    # Flip lower-is-better outcomes so a given sign of r means the same
    # direction of benefit for every outcome.
    outcome_sign_multiplier <- if (metric %in% lower_the_better_metrics) -1 else 1
    partial_r_unified <- outcome_sign_multiplier * r_raw
  } else {
    partial_r_unified <- NA
  }

  list(
    r_full = partial_r_unified,
    n_patients = length(unique(metric_data$PatientID)),
    n_obs = nrow(metric_data)
  )
}

# LOO modifier analysis function
# Refit the run_modifier_analysis_full() model after removing every row of one
# patient, and return only the sign-unified partial r of the interaction.
#
# Args:
#   data, metric, modifier_var: as in run_modifier_analysis_full().
#   excluded_patient_id: PatientID whose rows are dropped before fitting.
#
# Returns a single partial r: t / sqrt(t^2 + df), multiplied by -1 for
# outcomes in lower_the_better_metrics, or NA when t or df is missing or
# non-finite. Returns NULL when too few rows remain, the fit errors, or the
# interaction term cannot be identified uniquely.
run_modifier_analysis_loo <- function(data, metric, modifier_var, excluded_patient_id) {
  metric_data <- data %>%
    filter(Evaluation == metric) %>%
    filter(!is.na(Score_transformed) & !is.na(Baseline_z) &
             !is.na(!!sym(modifier_var))) %>%
    filter(PatientID != excluded_patient_id)  # Exclude one patient

  # The threshold counts rows (observations), not patients: each patient
  # contributes up to two rows, one per treatment arm.
  if (nrow(metric_data) < 8) return(NULL)

  metric_data <- metric_data %>% mutate(Sequence = factor(Group))

  modifier_formula <- paste("Score_transformed ~ Treatment * ", modifier_var,
                            "+ Period + Sequence + Baseline_z + (1|PatientID)")

  # A fitting error (e.g. a factor modifier left with a single level once a
  # patient is removed) means "no estimate" for this iteration.
  modifier_model <- tryCatch(
    lmer(as.formula(modifier_formula), data = metric_data, REML = TRUE,
         control = lmerControl(optimizer = "bobyqa")),
    error = function(e) NULL
  )

  if (is.null(modifier_model)) return(NULL)

  coef_table <- summary(modifier_model)$coefficients
  coef_names <- rownames(coef_table)

  # A numeric modifier yields "TreatmentaDBS:<var>", while a factor modifier
  # such as sex yields "TreatmentaDBS:<var><level>". Accept the exact name,
  # or else exactly one level-suffixed term.
  interaction_prefix <- paste0("TreatmentaDBS:", modifier_var)
  interaction_coef_name <- if (interaction_prefix %in% coef_names) {
    interaction_prefix
  } else {
    coef_names[startsWith(coef_names, interaction_prefix)]
  }
  if (length(interaction_coef_name) != 1) return(NULL)

  interaction_effect <- coef_table[interaction_coef_name, ]

  # Look up columns by name; "df" is present only when lmerTest augments the
  # summary table.
  t_val <- unname(interaction_effect["t value"])
  df_val <- unname(interaction_effect["df"])

  if (is.finite(t_val) && is.finite(df_val) && df_val > 0) {
    r_raw <- t_val / sqrt(t_val^2 + df_val)

    # Flip lower-is-better outcomes so a given sign of r means the same
    # direction of benefit for every outcome.
    outcome_sign_multiplier <- if (metric %in% lower_the_better_metrics) -1 else 1
    partial_r_unified <- outcome_sign_multiplier * r_raw
  } else {
    partial_r_unified <- NA
  }

  partial_r_unified
}

# =============================================================================
# Main LOO sensitivity analysis
# =============================================================================

cat("\n=== LOO Sensitivity Analysis ===\n")
cat("Starting Leave-One-Out analysis for all modifier-outcome combinations...\n\n")

# Get modifier variables: every standardized predictor present in the long
# data, plus sex (analysed as a factor) when it was available in the raw file.
standardized_predictor_names <- paste0(predictor_vars, "_standardized")
modifier_vars <- standardized_predictor_names[
  standardized_predictor_names %in% names(ancova_data)
]
if ("sex" %in% existing_predictors) {
  modifier_vars <- c(modifier_vars, "sex")
}

# Get all unique patient IDs. LOO iterations run over this vector, and the
# per-patient estimates below stay index-aligned with it.
all_patients <- unique(ancova_data$PatientID)
n_patients <- length(all_patients)

cat(sprintf("Total patients: %d\n", n_patients))
cat(sprintf("Total combinations to analyze: %d (outcomes) × %d (modifiers) = %d\n\n",
            length(eval_items_ordered), length(modifier_vars),
            length(eval_items_ordered) * length(modifier_vars)))

# Initialize results storage
loo_results <- list()
combination_counter <- 0
total_combinations <- length(eval_items_ordered) * length(modifier_vars)

# Run LOO analysis for each combination
for (metric in eval_items_ordered) {
  for (modifier in modifier_vars) {
    combination_counter <- combination_counter + 1

    # Same completeness rule as run_modifier_analysis_full(): at least 10
    # complete rows for this outcome/modifier pair.
    metric_data_check <- ancova_data %>%
      filter(Evaluation == metric) %>%
      filter(!is.na(Score_transformed) & !is.na(Baseline_z) & !is.na(!!sym(modifier)))

    if (nrow(metric_data_check) < 10) {
      cat(sprintf("[%d/%d] Skipping %s × %s (insufficient data)\n",
                  combination_counter, total_combinations, metric, modifier))
      next
    }

    cat(sprintf("[%d/%d] Processing %s × %s...",
                combination_counter, total_combinations, metric, modifier))

    # Get full dataset estimate
    full_result <- run_modifier_analysis_full(ancova_data, metric, modifier)

    if (is.null(full_result)) {
      cat(" failed (full model)\n")
      next
    }

    # Run LOO for each patient. Failed or NA refits are stored as NA so that
    # loo_r_values[i] always corresponds to all_patients[i].
    loo_r_values <- vapply(seq_along(all_patients), function(i) {
      loo_r <- run_modifier_analysis_loo(ancova_data, metric, modifier,
                                         all_patients[i])
      if (is.null(loo_r) || is.na(loo_r)) NA_real_ else as.numeric(loo_r)
    }, numeric(1))
    loo_success_count <- sum(!is.na(loo_r_values))

    # Calculate summary statistics
    loo_r_valid <- loo_r_values[!is.na(loo_r_values)]

    if (length(loo_r_valid) == 0) {
      cat(" failed (no valid LOO estimates)\n")
      next
    }

    loo_min <- min(loo_r_valid)
    loo_max <- max(loo_r_valid)
    loo_range <- loo_max - loo_min

    # Non-robust ("sign change") when leaving out a single patient can reverse
    # the direction of the interaction. That happens when the LOO estimates
    # straddle zero, or when at least one LOO estimate has the opposite sign
    # of the full-data estimate (checked only when r_full is available).
    r_full_value <- as.numeric(full_result$r_full)
    straddles_zero <- loo_min < 0 && loo_max > 0
    opposes_full <- is.finite(r_full_value) &&
      ((r_full_value > 0 && loo_min < 0) || (r_full_value < 0 && loo_max > 0))
    sign_change <- straddles_zero || opposes_full
    robust_status <- if (sign_change) "Non-robust" else "Robust"

    # Store results
    combination_key <- paste(metric, modifier, sep = "_")
    loo_results[[combination_key]] <- list(
      Outcome = metric,
      Baseline = modifier,
      r_full = full_result$r_full,
      n_patients_full = full_result$n_patients,
      n_obs_full = full_result$n_obs,
      loo_r_values = loo_r_values,
      loo_min = loo_min,
      loo_max = loo_max,
      loo_range = loo_range,
      sign_change = sign_change,
      robust_status = robust_status,
      n_loo_valid = length(loo_r_valid),
      n_loo_total = n_patients
    )

    cat(sprintf(" done (%d/%d LOO successful, %s)\n",
                loo_success_count, n_patients, robust_status))
  }
}

cat("\n=== Compiling Results ===\n")

# Compile results into data frame (one row per outcome x modifier).
# Binding onto an explicitly typed zero-row template keeps the expected
# columns and types even when no combination produced results, so the
# summaries and exports further down do not fail on a column-less tibble.
loo_summary <- bind_rows(
  tibble(
    Outcome = character(), Baseline = character(), r_full = numeric(),
    n_patients_full = integer(), n_obs_full = integer(),
    LOO_min = numeric(), LOO_max = numeric(), LOO_range = numeric(),
    Sign_change = logical(), Robust_status = character(),
    n_LOO_valid = integer(), n_LOO_total = integer()
  ),
  map_dfr(loo_results, function(result) {
    tibble(
      Outcome = result$Outcome,
      Baseline = result$Baseline,
      # as.numeric() drops the coefficient-table name and turns a logical NA
      # into NA_real_, so the column is always double.
      r_full = as.numeric(result$r_full),
      n_patients_full = result$n_patients_full,
      n_obs_full = result$n_obs_full,
      LOO_min = result$loo_min,
      LOO_max = result$loo_max,
      LOO_range = result$loo_range,
      Sign_change = result$sign_change,
      Robust_status = result$robust_status,
      n_LOO_valid = result$n_loo_valid,
      n_LOO_total = result$n_loo_total
    )
  })
)

# Add detailed LOO values (one row per successful LOO refit).
# result$loo_r_values is index-aligned with all_patients, with NA marking a
# failed refit, so the two vectors can be placed side by side.
loo_detailed_list <- map(loo_results, function(result) {
  tibble(
    Outcome = result$Outcome,
    Baseline = result$Baseline,
    Excluded_patient = as.character(all_patients),
    LOO_r = result$loo_r_values,
    r_full = as.numeric(result$r_full)
  ) %>%
    filter(!is.na(LOO_r))
})

# The inner bind_rows() flattens the (named) list of tibbles; the outer one
# adds the typed template so the result has the right columns even when empty.
loo_detailed <- bind_rows(
  tibble(
    Outcome = character(), Baseline = character(),
    Excluded_patient = character(), LOO_r = numeric(), r_full = numeric()
  ),
  bind_rows(loo_detailed_list)
)

# =============================================================================
# Summary statistics
# =============================================================================

# Overall robustness counts across all analysed combinations.
cat("\n=== LOO Sensitivity Analysis Summary ===\n")
cat(sprintf("Total combinations analyzed: %d\n", nrow(loo_summary)))
cat(sprintf("Non-robust combinations (sign change): %d (%.1f%%)\n",
            sum(loo_summary$Sign_change, na.rm = TRUE),
            100 * mean(loo_summary$Sign_change, na.rm = TRUE)))
cat(sprintf("Robust combinations: %d (%.1f%%)\n",
            sum(!loo_summary$Sign_change, na.rm = TRUE),
            100 * mean(!loo_summary$Sign_change, na.rm = TRUE)))

# Robustness profile of a LOO summary table grouped by one column (given by
# name): number of combinations, count and percentage flagged as
# sign-changing, and the mean/median LOO range.
summarise_loo_robustness <- function(loo_table, group_col) {
  loo_table %>%
    group_by(!!sym(group_col)) %>%
    summarise(
      n_combinations = n(),
      n_non_robust = sum(Sign_change, na.rm = TRUE),
      pct_non_robust = 100 * mean(Sign_change, na.rm = TRUE),
      mean_range = mean(LOO_range, na.rm = TRUE),
      median_range = median(LOO_range, na.rm = TRUE),
      .groups = "drop"
    )
}

summary_by_outcome <- summarise_loo_robustness(loo_summary, "Outcome")
cat("\n=== Summary by Outcome ===\n")
print(summary_by_outcome)

summary_by_baseline <- summarise_loo_robustness(loo_summary, "Baseline")
cat("\n=== Summary by Baseline Modifier ===\n")
print(summary_by_baseline)

# Sign-changing combinations, widest LOO range first.
non_robust_combinations <- loo_summary %>%
  filter(Sign_change) %>%
  arrange(desc(LOO_range))

cat("\n=== Top 10 Non-Robust Combinations (by range) ===\n")
if (nrow(non_robust_combinations) == 0) {
  cat("No non-robust combinations found.\n")
} else {
  print(
    non_robust_combinations %>%
      head(10) %>%
      select(Outcome, Baseline, r_full, LOO_min, LOO_max, LOO_range)
  )
}

# =============================================================================
# Export results
# =============================================================================

# Export summary table and all other result files.
# Output paths are declared once, so the files written and the file list
# reported on the console always agree.
loo_output_files <- c(
  summary     = "results_loo_sensitivity_analysis_summary.csv",
  detailed    = "results_loo_sensitivity_analysis_detailed.csv",
  by_outcome  = "results_loo_sensitivity_summary_by_outcome.csv",
  by_baseline = "results_loo_sensitivity_summary_by_baseline.csv",
  non_robust  = "results_loo_sensitivity_non_robust_combinations.csv",
  full_rds    = "results_loo_sensitivity_analysis_full.rds"
)

write_csv(loo_summary, loo_output_files[["summary"]])

# Export detailed LOO values
write_csv(loo_detailed, loo_output_files[["detailed"]])

# Export summary statistics
write_csv(summary_by_outcome, loo_output_files[["by_outcome"]])
write_csv(summary_by_baseline, loo_output_files[["by_baseline"]])

# Export non-robust combinations (only written when at least one exists)
has_non_robust <- nrow(non_robust_combinations) > 0
if (has_non_robust) {
  write_csv(non_robust_combinations, loo_output_files[["non_robust"]])
}

# Save full results object (for potential future use)
saveRDS(loo_results, loo_output_files[["full_rds"]])

# Report exactly the files written above, numbered consecutively.
file_descriptions <- c(
  summary     = "Main summary table",
  detailed    = "Detailed LOO values per patient",
  by_outcome  = "Summary by outcome",
  by_baseline = "Summary by baseline modifier",
  non_robust  = "Non-robust combinations",
  full_rds    = "Full results object"
)
written_keys <- names(loo_output_files)
if (!has_non_robust) {
  written_keys <- setdiff(written_keys, "non_robust")
}

cat("\n=== Results Exported ===\n")
cat("Files created:\n")
cat(sprintf("%d. %s - %s\n", seq_along(written_keys),
            loo_output_files[written_keys], file_descriptions[written_keys]),
    sep = "")

cat("\n=== LOO Sensitivity Analysis Complete ===\n")

